# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import abc
import pickle
from pathlib import Path
import warnings
import pandas as pd

from typing import Tuple, Union, List, Dict

from qlib.data import D
from qlib.utils import load_dataset, init_instance_by_config, time_to_slc_point
from qlib.log import get_module_logger
from qlib.utils.serial import Serializable


class DataLoader(abc.ABC):
    """
    DataLoader is designed for loading raw data from the original data source.
    """

    @abc.abstractmethod
    def load(self, instruments, start_time=None, end_time=None) -> pd.DataFrame:
        """
        load the data as pd.DataFrame.

        Example of the data (the multi-index of the columns is optional):

            .. code-block:: text

                                        feature                                                             label
                                        $close     $volume     Ref($close, 1)  Mean($close, 3)  $high-$low  LABEL0
                datetime    instrument
                2010-01-04  SH600000    81.807068  17145150.0       83.737389        83.016739    2.741058  0.0032
                            SH600004    13.313329  11800983.0       13.313329        13.317701    0.183632  0.0042
                            SH600005    37.796539  12231662.0       38.258602        37.919757    0.970325  0.0289


        Parameters
        ----------
        instruments : str or dict
            it can either be the market name or the config file of instruments generated by InstrumentProvider.
            If the value of instruments is None, it means that no filtering is done.
        start_time : str
            start of the time range.
        end_time : str
            end of the time range.

        Returns
        -------
        pd.DataFrame:
            data loaded from the underlying source

        Raises
        ------
        KeyError:
            if the instruments filter is not supported
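
        Example of a call (a sketch: any concrete subclass such as QlibDataLoader can stand in
        for data_loader, and the instrument pool name is only an illustration):

            .. code-block:: python

                df = data_loader.load(instruments="csi300", start_time="2010-01-01", end_time="2017-12-31")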
        """


class DLWParser(DataLoader):
    """
    (D)ata(L)oader (W)ith (P)arser for features and names

    This class is extracted so that QlibDataLoader and other data loaders (such as QdbDataLoader) can share the field-parsing logic.
    """

    def __init__(self, config: Union[list, tuple, dict]):
        """
        Parameters
        ----------
        config : Union[list, tuple, dict]
            Config will be used to describe the fields and column names

            .. code-block::

                <config> := {
                    "group_name1": <fields_info1>
                    "group_name2": <fields_info2>
                }
                or
                <config> := <fields_info>

                <fields_info> := ["expr", ...] | (["expr", ...], ["col_name", ...])
                # NOTE: if only the expressions are given, the column names default to the expressions themselves
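
            For example (a sketch: the expressions are Qlib expression strings, and the column names can be chosen freely):

            .. code-block:: python

                # column names default to the expressions themselves
                config = ["$close", "$volume"]
                # expressions and column names given separately
                config = (["$close", "$volume"], ["CLOSE", "VOLUME"])
                # grouped fields
                config = {
                    "feature": (["$close", "$volume"], ["CLOSE", "VOLUME"]),
                    "label": (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]),
                }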
        """
        self.is_group = isinstance(config, dict)

        if self.is_group:
            self.fields = {grp: self._parse_fields_info(fields_info) for grp, fields_info in config.items()}
        else:
            self.fields = self._parse_fields_info(config)

    def _parse_fields_info(self, fields_info: Union[list, tuple]) -> Tuple[list, list]:
        if not isinstance(fields_info, (list, tuple)):
            raise TypeError(f"Unsupported type of fields_info: {type(fields_info)}")

        if len(fields_info) == 0:
            raise ValueError("The size of fields must be greater than 0")

        if isinstance(fields_info[0], str):
            # only expressions are given; column names default to the expressions themselves
            exprs = names = fields_info
        elif isinstance(fields_info[0], (list, tuple)):
            # expressions and column names are given separately
            exprs, names = fields_info
        else:
            raise NotImplementedError(f"Unsupported type of the first element of fields_info: {type(fields_info[0])}")
        return exprs, names

    @abc.abstractmethod
    def load_group_df(
        self,
        instruments,
        exprs: list,
        names: list,
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        gp_name: str = None,
    ) -> pd.DataFrame:
        """
        load the dataframe for a specific group

        Parameters
        ----------
        instruments :
            the instruments.
        exprs : list
            the expressions to describe the content of the data.
        names : list
            the column names of the data.
        start_time : Union[str, pd.Timestamp]
            start of the time range.
        end_time : Union[str, pd.Timestamp]
            end of the time range.
        gp_name : str
            the name of the group to load.

        Returns
        -------
        pd.DataFrame:
            the queried dataframe.
        """

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        if self.is_group:
            df = pd.concat(
                {
                    grp: self.load_group_df(instruments, exprs, names, start_time, end_time, grp)
                    for grp, (exprs, names) in self.fields.items()
                },
                axis=1,
            )
        else:
            exprs, names = self.fields
            df = self.load_group_df(instruments, exprs, names, start_time, end_time)
        return df


class QlibDataLoader(DLWParser):
    """Same as QlibDataLoader. The fields can be define by config"""

    def __init__(
        self,
        config: Union[list, tuple, dict],
        filter_pipe: List = None,
        swap_level: bool = True,
        freq: Union[str, dict] = "day",
        inst_processors: Union[dict, list] = None,
    ):
        """
        Parameters
        ----------
        config : Union[list, tuple, dict]
            Please refer to the doc of DLWParser
        filter_pipe : List
            Filter pipe for the instruments
        swap_level : bool
            Whether to swap the levels of the MultiIndex, so the result is indexed by <datetime, instrument>
        freq : dict or str
            If type(config) == dict and type(freq) == str, load config data using freq.
            If type(config) == dict and type(freq) == dict, load config[<group_name>] data using freq[<group_name>].
        inst_processors : dict | list
            If inst_processors is not None and type(config) == dict, load config[<group_name>] data using inst_processors[<group_name>].
            If inst_processors is a list, it will be applied to all groups.
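
        Example of usage (a sketch: it assumes qlib.init has been called and the "csi300" instrument pool is available from the data provider):

            .. code-block:: python

                qdl = QlibDataLoader(config={
                    "feature": (["$close", "Mean($close, 3)"], ["CLOSE", "MA3"]),
                    "label": (["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"]),
                })
                df = qdl.load(instruments="csi300", start_time="2010-01-01", end_time="2017-12-31")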
        """
        self.filter_pipe = filter_pipe
        self.swap_level = swap_level
        self.freq = freq

        # inst_processors: processors applied to each instrument's data when loading (e.g. for resampling)
        self.inst_processors = inst_processors if inst_processors is not None else {}
        assert isinstance(
            self.inst_processors, (dict, list)
        ), f"inst_processors(={self.inst_processors}) must be dict or list"

        super().__init__(config)

        if self.is_group:
            # check the per-group freq / inst_processors config
            if isinstance(freq, dict):
                for _gp in config.keys():
                    if _gp not in freq:
                        raise ValueError(f"freq(={freq}) missing group(={_gp})")
                assert (
                    self.inst_processors
                ), f"freq(={self.freq}), inst_processors(={self.inst_processors}) cannot be None/empty"

    def load_group_df(
        self,
        instruments,
        exprs: list,
        names: list,
        start_time: Union[str, pd.Timestamp] = None,
        end_time: Union[str, pd.Timestamp] = None,
        gp_name: str = None,
    ) -> pd.DataFrame:
        if instruments is None:
            warnings.warn("`instruments` is not set, will load all stocks")
            instruments = "all"
        if isinstance(instruments, str):
            instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
        elif self.filter_pipe is not None:
            warnings.warn("`filter_pipe` is not None, but it will not be used when `instruments` is given as a list")

        freq = self.freq[gp_name] if isinstance(self.freq, dict) else self.freq
        inst_processors = (
            self.inst_processors if isinstance(self.inst_processors, list) else self.inst_processors.get(gp_name, [])
        )
        df = D.features(instruments, exprs, start_time, end_time, freq=freq, inst_processors=inst_processors)
        df.columns = names
        if self.swap_level:
            df = df.swaplevel().sort_index()  # NOTE: if swaplevel, return <datetime, instrument>
        return df


class StaticDataLoader(DataLoader, Serializable):
    """
    DataLoader that supports loading data from a file or from an object provided directly.
    """

    include_attr = ["_config"]

    def __init__(self, config: Union[dict, str, pd.DataFrame], join="outer"):
        """
        Parameters
        ----------
        config : Union[dict, str, pd.DataFrame]
            it can be one of the following:

            - dict: {fields_group: <path or object>}
            - str or Path: the path of a file that can be loaded (e.g. a pickled DataFrame or a ".parquet" file)
            - pd.DataFrame: the data itself
        join : str
            How to align different dataframes (it is passed to pd.concat when config is a dict)
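
        Example of usage (a sketch: the paths below are hypothetical and should point to files that qlib.utils.load_dataset can read, e.g. pickled DataFrames indexed by <datetime, instrument>):

            .. code-block:: python

                sdl = StaticDataLoader(config={
                    "feature": "my_feature_df.pkl",  # hypothetical path
                    "label": "my_label_df.pkl",  # hypothetical path
                })
                df = sdl.load(start_time="2010-01-01", end_time="2017-12-31")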
        """
        self._config = config  # using "_" to avoid conflict with the method `config` of Serializable
        self.join = join
        self._data = None

    def __getstate__(self) -> dict:
        # avoid pickling `self._data` (it can be reloaded from `_config`), but keep `_config` itself
        return {k: v for k, v in self.__dict__.items() if k != "_data"}

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        self._maybe_load_raw_data()

        # 1) Filter by instruments
        if instruments is None:
            df = self._data
        else:
            df = self._data.loc(axis=0)[:, instruments]

        # 2) Filter by Datetime
        if start_time is None and end_time is None:
            return df  # NOTE: avoid copy by loc
        # NOTE: pd.Timestamp(None) is NaT, and slicing with NaT does not fetch the correct range, so None must be kept as None.
        start_time = time_to_slc_point(start_time)
        end_time = time_to_slc_point(end_time)
        return df.loc[start_time:end_time]

    def _maybe_load_raw_data(self):
        if self._data is not None:
            return
        if isinstance(self._config, dict):
            self._data = pd.concat(
                {fields_group: load_dataset(path_or_obj) for fields_group, path_or_obj in self._config.items()},
                axis=1,
                join=self.join,
            )
            self._data.sort_index(inplace=True)
        elif isinstance(self._config, (str, Path)):
            if str(self._config).strip().endswith(".parquet"):
                self._data = pd.read_parquet(self._config, engine="pyarrow")
            else:
                with Path(self._config).open("rb") as f:
                    self._data = pickle.load(f)
        elif isinstance(self._config, pd.DataFrame):
            self._data = self._config


class NestedDataLoader(DataLoader):
    """
    When there are multiple DataLoaders, this class can be used to combine them.
    """

    def __init__(self, dataloader_l: List[Dict], join="left") -> None:
        """

        Parameters
        ----------
        dataloader_l : list[dict]
            A list of dataloaders (DataLoader instances or their configs), for example

            .. code-block:: python

                nd = NestedDataLoader(
                    dataloader_l=[
                        {
                            "class": "qlib.contrib.data.loader.Alpha158DL",
                        }, {
                            "class": "qlib.contrib.data.loader.Alpha360DL",
                            "kwargs": {
                                "config": {
                                    "label": ( ["Ref($close, -2)/Ref($close, -1) - 1"], ["LABEL0"])
                                }
                            }
                        }
                    ]
                )
        join : str
            it will be passed as the `how` argument to pd.merge when combining the dataframes.
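
        After construction, the combined data can be fetched like any other DataLoader (a sketch: it assumes qlib has been initialized so the inner loaders can query data):

            .. code-block:: python

                df = nd.load(instruments="csi300", start_time="2010-01-01", end_time="2017-12-31")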
        """
        super().__init__()
        self.data_loader_l = [
            (dl if isinstance(dl, DataLoader) else init_instance_by_config(dl)) for dl in dataloader_l
        ]
        self.join = join

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        df_full = None
        for dl in self.data_loader_l:
            try:
                df_current = dl.load(instruments, start_time, end_time)
            except KeyError:
                warnings.warn(
                    "The given `instruments` cannot be processed by this dataloader; "
                    "falling back to `instruments=None` to load all the data."
                )
                df_current = dl.load(instruments=None, start_time=start_time, end_time=end_time)
            if df_full is None:
                df_full = df_current
            else:
                current_columns = df_current.columns.tolist()
                full_columns = df_full.columns.tolist()
                # overlapping columns are dropped from df_full so the later loader's values take precedence
                columns_to_drop = [col for col in current_columns if col in full_columns]
                df_full.drop(columns=columns_to_drop, inplace=True)
                df_full = pd.merge(df_full, df_current, left_index=True, right_index=True, how=self.join)
        return df_full.sort_index(axis=1)


class DataLoaderDH(DataLoader):
    """DataLoaderDH
    DataLoader based on (D)ata (H)andler
    It is designed to load multiple data from data handler
    - If you just want to load data from single datahandler, you can write them in single data handler

    TODO: What make this module not that easy to use.

    - For online scenario

        - The underlayer data handler should be configured. But data loader doesn't provide such interface & hook.
    """

    def __init__(self, handler_config: dict, fetch_kwargs: dict = None, is_group=False):
        """
        Parameters
        ----------
        handler_config : dict
            handler_config will be used to describe the handlers

            .. code-block::

                <handler_config> := {
                    "group_name1": <handler>
                    "group_name2": <handler>
                }
                or
                <handler_config> := <handler>
                <handler> := DataHandler Instance | DataHandler Config

        fetch_kwargs : dict
            fetch_kwargs will be used to describe the different arguments of the fetch method, such as col_set, squeeze, data_key, etc.

        is_group : bool
            is_group will be used to describe whether the keys of handler_config are group names

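        Example of usage (a sketch: the handler config below is only an illustration; any DataHandler instance or config is accepted):

            .. code-block:: python

                dl = DataLoaderDH(handler_config={
                    "class": "Alpha158",
                    "module_path": "qlib.contrib.data.handler",
                    "kwargs": {
                        "instruments": "csi300",
                        "start_time": "2010-01-01",
                        "end_time": "2017-12-31",
                    },
                })
                df = dl.load(start_time="2010-01-01", end_time="2017-12-31")
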
        """
        from qlib.data.dataset.handler import DataHandler  # pylint: disable=C0415

        if is_group:
            self.handlers = {
                grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
            }
        else:
            self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)

        self.is_group = is_group
        self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
        self.fetch_kwargs.update(fetch_kwargs or {})

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        if instruments is not None:
            get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")

        if self.is_group:
            df = pd.concat(
                {
                    grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
                    for grp, dh in self.handlers.items()
                },
                axis=1,
            )
        else:
            df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
        return df
